"""
Copyright (c) 2023 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import pytest
import torch
from tests.test_helpers.jit_utils import (
    gen_decode_attention_modules,
    gen_prefill_attention_modules,
)

import flashinfer
from flashinfer.utils import has_flashinfer_jit_cache


@pytest.fixture(
    autouse=not has_flashinfer_jit_cache(),
    scope="module",
)
def warmup_jit():
    flashinfer.jit.build_jit_specs(
        gen_decode_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_caps
        )
        + gen_prefill_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_caps
            [False],  # use_fp16_qk_reductions
        ),
        verbose=False,
    )
    yield


def ceil_div(a, b):
    return (a + b - 1) // b


@pytest.mark.parametrize("stage", ["decode", "append"])
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("unique_kv_len", [37, 17])
@pytest.mark.parametrize("shared_kv_len", [128, 512, 2048])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_attention_with_shared_prefix_paged_kv_cache(
    stage,
    batch_size,
    unique_kv_len,
    shared_kv_len,
    num_heads,
    causal,
    head_dim,
    page_size,
):
    if stage == "decode" and causal:
        pytest.skip("Causal attention is not required in decode stage")
    assert shared_kv_len % page_size == 0
    kv_layout = "NHD"
    if stage == "append":
        q = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half()
        q_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len
    else:
        q = torch.randn(batch_size, num_heads, head_dim).to(0).half()
        q_indptr = torch.arange(0, batch_size + 1).to(0).int()
    k_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half()
    v_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half()
    k_unique = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half()
    v_unique = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half()

    kv_data = (
        torch.zeros(
            ceil_div(shared_kv_len, page_size)
            + batch_size * ceil_div(unique_kv_len, page_size),
            2,
            page_size,
            num_heads,
            head_dim,
        )
        .to(0)
        .half()
    )
    shared_kv_indices = torch.arange(0, ceil_div(shared_kv_len, page_size)).to(0).int()
    shared_append_indptr = torch.arange(0, 2).to(0).int() * shared_kv_len
    shared_kv_indptr = torch.arange(0, 2).to(0).int() * ceil_div(
        shared_kv_len, page_size
    )
    shared_last_page_len = torch.full(
        (1,), (shared_kv_len - 1) % page_size + 1, dtype=torch.int32
    ).to(0)
    flashinfer.append_paged_kv_cache(
        k_shared,
        v_shared,
        *flashinfer.get_batch_indices_positions(
            shared_append_indptr,
            flashinfer.get_seq_lens(shared_kv_indptr, shared_last_page_len, page_size),
            k_shared.shape[0],
        ),
        kv_data,
        shared_kv_indices,
        shared_kv_indptr,
        shared_last_page_len,
        kv_layout,
    )
    unique_kv_indices = torch.arange(
        0, batch_size * ceil_div(unique_kv_len, page_size)
    ).to(0).int() + ceil_div(shared_kv_len, page_size)
    unique_append_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len
    unique_kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * ceil_div(
        unique_kv_len, page_size
    )
    unique_last_page_len = torch.full(
        (batch_size,), (unique_kv_len - 1) % page_size + 1, dtype=torch.int32
    ).to(0)
    flashinfer.append_paged_kv_cache(
        k_unique,
        v_unique,
        *flashinfer.get_batch_indices_positions(
            unique_append_indptr,
            flashinfer.get_seq_lens(unique_kv_indptr, unique_last_page_len, page_size),
            k_unique.shape[0],
        ),
        kv_data,
        unique_kv_indices,
        unique_kv_indptr,
        unique_last_page_len,
        kv_layout,
    )

    if stage == "decode":
        multi_level_wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
            2, torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
        )
        shared_prefix_decode_wrapper = (
            flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
                torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
            )
        )
    else:
        multi_level_wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
            2, torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
        )
        shared_prefix_prefill_wrapper = (
            flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
                torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
            )
        )

    qo_indptr_top = torch.tensor([0, q.shape[0]], dtype=torch.int32).to(0)
    if stage == "decode":
        qo_indptr_bottom = torch.arange(0, batch_size + 1, dtype=torch.int32).to(0)
        multi_level_wrapper.plan(
            [qo_indptr_top, qo_indptr_bottom],
            [shared_kv_indptr, unique_kv_indptr],
            [shared_kv_indices, unique_kv_indices],
            [shared_last_page_len, unique_last_page_len],
            num_heads,
            num_heads,
            head_dim,
            page_size,
        )
        o_multi_level = multi_level_wrapper.run(q, kv_data)
    else:
        qo_indptr_bottom = (
            torch.arange(0, batch_size + 1, dtype=torch.int32).to(0) * unique_kv_len
        )
        multi_level_wrapper.plan(
            [qo_indptr_top, qo_indptr_bottom],
            [shared_kv_indptr, unique_kv_indptr],
            [shared_kv_indices, unique_kv_indices],
            [shared_last_page_len, unique_last_page_len],
            num_heads,
            num_heads,
            head_dim,
            page_size,
            causal=causal,
        )
        o_multi_level = multi_level_wrapper.run(q, kv_data)

    if stage == "decode":
        shared_prefix_decode_wrapper.begin_forward(
            unique_kv_indptr,
            unique_kv_indices,
            unique_last_page_len,
            num_heads,
            num_heads,
            head_dim,
            page_size,
        )
        o_two_level = shared_prefix_decode_wrapper.forward(
            q, k_shared, v_shared, kv_data
        )
    else:
        shared_prefix_prefill_wrapper.begin_forward(
            q_indptr,
            unique_kv_indptr,
            unique_kv_indices,
            unique_last_page_len,
            num_heads,
            num_heads,
            head_dim,
            page_size,
        )
        o_two_level = shared_prefix_prefill_wrapper.forward(
            q, k_shared, v_shared, kv_data, causal=causal
        )

    torch.testing.assert_close(o_multi_level, o_two_level, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("seed", [0])
@pytest.mark.parametrize("num_tries", [50])
def test_merge_state_in_place_with_mask(seed, num_tries):
    seq_len = 512
    num_heads = 32
    head_dim = 128
    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    va_orginal = va.clone()
    sa_original = sa.clone()

    # No mask.
    flashinfer.merge_state_in_place(va, sa, vb, sb)
    va_merged_ref = va.clone()
    sa_merged_ref = sa.clone()
    assert not torch.allclose(va_merged_ref, va_orginal)
    assert not torch.allclose(sa_merged_ref, sa_original)

    # Mask with all 1s. Should be identical to no mask.
    mask = torch.ones(seq_len, dtype=torch.bool).to("cuda:0")
    va = va_orginal.clone()
    sa = sa_original.clone()
    flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
    va_merged = va
    sa_merged = sa
    torch.testing.assert_close(va_merged, va_merged_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(sa_merged, sa_merged_ref, rtol=1e-3, atol=1e-3)

    # Mask with all zeros. Input and output should be identical.
    mask = torch.zeros(seq_len, dtype=torch.bool).to("cuda:0")
    va = va_orginal.clone()
    sa = sa_original.clone()
    flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
    va_merged = va
    sa_merged = sa
    torch.testing.assert_close(va_merged, va_orginal, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(sa_merged, sa_original, rtol=1e-3, atol=1e-3)

    # Test some random masks.
    randgen = torch.Generator(device="cuda:0")
    randgen.manual_seed(seed)
    for _ in range(num_tries):
        rand_mask = (
            torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0")
            > 0.5
        ).to(dtype=torch.bool)
        true_indices = rand_mask.nonzero()
        false_indices = (rand_mask == 0).nonzero()
        va = va_orginal.clone()
        sa = sa_original.clone()
        flashinfer.merge_state_in_place(va, sa, vb, sb, mask=rand_mask)
        va_merged = va
        sa_merged = sa

        torch.testing.assert_close(
            va_merged[false_indices],
            va_orginal[false_indices],
            rtol=1e-3,
            atol=1e-3,
        )
        torch.testing.assert_close(
            sa_merged[false_indices],
            sa_original[false_indices],
            rtol=1e-3,
            atol=1e-3,
        )
        torch.testing.assert_close(
            va_merged[true_indices],
            va_merged_ref[true_indices],
            rtol=1e-3,
            atol=1e-3,
        )
        torch.testing.assert_close(
            sa_merged[true_indices],
            sa_merged_ref[true_indices],
            rtol=1e-3,
            atol=1e-3,
        )


if __name__ == "__main__":
    test_batch_attention_with_shared_prefix_paged_kv_cache(
        "decode", 12, 37, 128, 8, False, 128, 16
    )
    test_batch_attention_with_shared_prefix_paged_kv_cache(
        "append", 12, 37, 128, 8, True, 128, 16
    )
